#lang racket/base
(require "../common/set.rkt"
         "../common/struct-star.rkt"
         "../syntax/syntax.rkt"
         "../syntax/scope.rkt"
         "../syntax/binding.rkt"
         "../common/phase.rkt"
         "lift-key.rkt")

(provide (struct*-out root-expand-context)
         make-root-expand-context

         apply-post-expansion
         post-expansion-scope

         root-expand-context-encode-for-module
         root-expand-context-decode-for-module)

;; A `root-expand-context` is a subset of `expand-context` that is
;; preserved from a module's expansion for later use in a namespace
;; generated by `module->namespace` --- or preserved across different
;; expansions at the top level.
;;
;; The struct is declared with `struct*` from "../common/struct-star.rkt";
;; a `*` before a field name presumably marks that field as mutable
;; (NOTE(review): confirm against "struct-star.rkt").
;;
;; `root-expand-context-encode-for-module` and
;; `root-expand-context-decode-for-module` round-trip this struct through a
;; syntax object; on decode, `self-mpi` comes from the decoder's caller and
;; `top-level-bind-scope` and `lift-key` are regenerated rather than decoded.
(struct* root-expand-context
         (self-mpi        ; MPI for the enclosing module during compilation
          module-scopes   ; list of scopes for enclosing module or top level; includes next two fields
          ;                 NOTE(review): `make-root-expand-context` puts the post-expansion scope in
          ;                 this list but not `top-level-bind-scope`; confirm the "next two" claim
          * post-expansion  ; #f, a shifted multiscope to push to every expansion (often module's inside edge),
          ;                   a pair of a shifted multiscope and a list of shifts, or a procedure (when not
          ;                   at the actual root, because an actual root needs to be marshalable);
          ;                   interpreted by `apply-post-expansion`
          top-level-bind-scope  ; #f or a scope to constrain expansion bindings; see "expand-bind-top.rkt"
          all-scopes-stx  ; scopes like the initial import, which correspond to original forms
          * use-site-scopes ; #f or boxed list: scopes that should be pruned from binders
          defined-syms    ; phase -> sym -> id; symbols picked for bindings
          * frame-id      ; #f or a gensym to identify a binding frame; 'all matches any for use-site scopes
          ;                 (the module decoder accepts only a symbol here)
          counter         ; box of an integer; used for generating names deterministically
          lift-key        ; identifies (via `syntax-local-lift-context`) a target for lifts
          )) ; after adding a field, update `copy-module-context` in "context.rkt"

;; Build a fresh `root-expand-context`.
;;  #:self-mpi             module path index for the enclosing module
;;  #:initial-scopes       scopes placed after the outside scope in `module-scopes`
;;  #:outside-scope        defaults to `top-level-common-scope`
;;  #:post-expansion-scope defaults to a new 'top-level multi-scope; it becomes
;;                         both the `post-expansion` field and the first
;;                         element of `module-scopes`
;;  #:all-scopes-stx       when #f, an empty syntax object carrying every
;;                         scope in `module-scopes` is used instead
;; The use-site-scopes box starts empty, defined-syms starts as an empty
;; mutable eqv table, the frame id is a fresh uninterned symbol, and the
;; counter starts at 0.
(define (make-root-expand-context #:self-mpi self-mpi
                                  #:initial-scopes [initial-scopes null]
                                  #:outside-scope [outside-scope top-level-common-scope]
                                  #:post-expansion-scope [post-expansion-scope (new-multi-scope 'top-level)]
                                  #:all-scopes-stx [all-scopes-stx #f])
  ;; Post-expansion scope first, then the outside scope, then the initial scopes
  (define scopes (cons post-expansion-scope
                       (cons outside-scope initial-scopes)))
  (define bind-scope (new-scope 'module))
  (define scopes-stx (if all-scopes-stx
                         all-scopes-stx
                         (add-scopes empty-syntax scopes)))
  (root-expand-context self-mpi
                       scopes                                     ; module-scopes
                       post-expansion-scope                       ; post-expansion
                       bind-scope                                 ; top-level-bind-scope
                       scopes-stx                                 ; all-scopes-stx
                       (box '())                                  ; use-site-scopes
                       (make-hasheqv)                             ; defined-syms
                       (string->uninterned-symbol "root-frame")   ; frame-id
                       (box 0)                                    ; counter
                       (generate-lift-key)))                      ; lift-key

;; ----------------------------------------

;; Apply a `post-expansion` value `pe` to the syntax object `s`:
;;  - (cons sms shifts)    => push `sms` onto `s`, then add `shifts`
;;  - shifted multi-scope  => push it onto `s`
;;  - any other true value => call it as a procedure on `s`
;;  - #f                   => return `s` unchanged
(define (apply-post-expansion pe s)
  (cond
    [(pair? pe)
     (let ([pushed (push-scope s (car pe))])
       (syntax-add-shifts pushed (cdr pe)))]
    [(shifted-multi-scope? pe) (push-scope s pe)]
    [pe (pe s)]
    [else s]))

;; Extract the shifted multi-scope from a `post-expansion` value that is
;; either a (cons sms shifts) pair or a shifted multi-scope itself.  Any
;; other value (such as #f or a procedure) is reported as an internal error.
(define (post-expansion-scope pe)
  (if (pair? pe)
      (car pe)
      (if (shifted-multi-scope? pe)
          pe
          (error 'post-expansion-scope "internal error: cannot extract scope from ~s" pe))))

;; ----------------------------------------

;; Encode the persistent parts of `ctx` in a syntax object that can be
;; serialized and deserialized.  The result wraps a 7-slot vector, read back
;; by `root-expand-context-decode-for-module`:
;;  0: empty syntax carrying `module-scopes`
;;  1: empty syntax with the post-expansion applied
;;  2: `all-scopes-stx` with module path index `orig-self` shifted to `new-self`
;;  3: empty syntax carrying the (unboxed) use-site scopes
;;  4: defined-syms, with the outer phase table copied into an immutable hash
;;  5: the frame id
;;  6: the current counter value
;; `self-mpi`, `top-level-bind-scope`, and `lift-key` are not encoded.
(define (root-expand-context-encode-for-module ctx orig-self new-self)
  (define module-scopes-stx
    (add-scopes empty-syntax (root-expand-context-module-scopes ctx)))
  (define post-expansion-stx
    (apply-post-expansion (root-expand-context-post-expansion ctx) empty-syntax))
  (define shifted-all-scopes-stx
    (syntax-module-path-index-shift (root-expand-context-all-scopes-stx ctx) orig-self new-self))
  (define use-site-scopes-stx
    (add-scopes empty-syntax (unbox (root-expand-context-use-site-scopes ctx))))
  ;; Only the outer phase table is made immutable; inner tables are shared as-is
  (define frozen-defined-syms
    (for/hasheqv ([(phase ht) (in-hash (root-expand-context-defined-syms ctx))])
      (values phase ht)))
  (define frame-id (root-expand-context-frame-id ctx))
  (define counter-value (unbox (root-expand-context-counter ctx)))
  (datum->syntax #f
                 (vector module-scopes-stx
                         post-expansion-stx
                         shifted-all-scopes-stx
                         use-site-scopes-stx
                         frozen-defined-syms
                         frame-id
                         counter-value)))

;; Decode the value produced by `root-expand-context-encode-for-module` into a
;; new `root-expand-context` whose self MPI is `self`.  A fresh
;; `top-level-bind-scope` and lift key are generated.  The use-site scopes,
;; counter, and defined-syms tables come back in fresh mutable containers.
;; Raises an error mentioning `vec-s` when the encoding has the wrong shape.
(define (root-expand-context-decode-for-module vec-s self)
  (define vec (and (syntax? vec-s) (syntax-e vec-s)))
  (define (slot i) (vector-ref vec i))
  (define (slot-e i) (syntax-e (slot i)))
  (unless (and (vector? vec)
               (= 7 (vector-length vec))
               (syntax? (slot 0))
               (syntax-with-one-scope? (slot 1))
               (syntax? (slot 2))
               (syntax? (slot 3))
               (defined-syms-hash? (slot-e 4))
               (symbol? (slot-e 5))
               (exact-nonnegative-integer? (slot-e 6)))
    (error 'root-expand-context-decode-for-module
           "bad encoding: ~s"
           vec-s))
  (define post-expansion-stx (slot 1))
  (root-expand-context self
                       (extract-scope-list (slot 0))             ; module-scopes
                       (cons (extract-scope post-expansion-stx)
                             (extract-shifts post-expansion-stx)) ; post-expansion
                       (new-scope 'module)                       ; top-level-bind-scope
                       (slot 2)                                  ; all-scopes-stx
                       (box (extract-scope-list (slot 3)))       ; use-site-scopes
                       (unpack-defined-syms (slot 4))            ; defined-syms
                       (slot-e 5)                                ; frame-id
                       (box (slot-e 6))                          ; counter
                       (generate-lift-key)))                     ; lift-key

;; Recognize the decoded form of an encoded defined-syms table: a hash that
;; maps phases to syntax objects, each wrapping a hash from symbols to
;; identifiers.  Returns #f (rather than raising) for any other shape so that
;; `root-expand-context-decode-for-module` can report "bad encoding".
(define (defined-syms-hash? v)
  (and (hash? v)                ; `in-hash` would raise on a non-hash
       (for/and ([(phase ht-s) (in-hash v)])
         (and (phase? phase)
              (syntax? ht-s)    ; `syntax-e` would raise on a non-syntax value
              (hash? (syntax-e ht-s))
              (for/and ([(sym id) (in-hash (syntax-e ht-s))])
                (and (symbol? sym)
                     (identifier? id)))))))

;; List the phase-0 scopes attached to `stx`, each passed through
;; `generalize-scope`, in `set->list` order.
(define (extract-scope-list stx)
  (for/list ([sc (in-list (set->list (syntax-scope-set stx 0)))])
    (generalize-scope sc)))

;; #t when `stx` is a syntax object whose phase-0 scope set has exactly one
;; element; #f otherwise (including for non-syntax values).
(define (syntax-with-one-scope? stx)
  (cond
    [(syntax? stx) (= (set-count (syntax-scope-set stx 0)) 1)]
    [else #f]))

;; Return the generalized form of the first scope in `stx`'s phase-0 scope
;; set.  The decoder calls this only after `syntax-with-one-scope?` succeeds.
(define (extract-scope stx)
  (generalize-scope (set-first (syntax-scope-set stx 0))))

;; Return the module-path-index shifts recorded on syntax object `s`.
(define (extract-shifts s)
  (syntax-mpi-shifts s))

;; Rebuild a mutable defined-syms table from its decoded encoding: `v` is
;; syntax wrapping a phase-keyed hash whose values are syntax objects that
;; each wrap a symbol -> identifier hash.
;; The result is a fresh mutable eqv?-keyed table (phases) of fresh mutable
;; eq?-keyed tables.  eq? keys suffice because `defined-syms-hash?` admits
;; only symbol keys, and each table is built once rather than being built
;; immutably and then copied.
(define (unpack-defined-syms v)
  (define phase-table (make-hasheqv))
  (for ([(phase ht-s) (in-hash (syntax-e v))])
    (define sym-table (make-hasheq))
    (for ([(sym id) (in-hash (syntax-e ht-s))])
      (hash-set! sym-table sym id))
    (hash-set! phase-table phase sym-table))
  phase-table)
